import argparse
from time import time
import math

import torch
import torch.nn as nn
import torch.optim
import torch.utils.data
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import numpy as np


import src as models

import os

# Sorted names of every public, lower-case, callable attribute exposed by the
# ``models`` package; used as the allowed values of the ``--model`` option.
model_names = sorted(
    attr
    for attr, obj in vars(models).items()
    if attr.islower() and not attr.startswith("_") and callable(obj)
)

# Per-dataset settings, keyed by the lower-case dataset name accepted by
# ``--dataset``.  ``mean`` / ``std`` hold three values each and are the
# statistics handed to ``transforms.Normalize`` during evaluation.
DATASETS = {
    'cifar10': dict(
        num_classes=10,
        img_size=32,
        mean=[0.4914, 0.4822, 0.4465],
        std=[0.2470, 0.2435, 0.2616],
    ),
    'cifar100': dict(
        num_classes=100,
        img_size=32,
        mean=[0.5071, 0.4867, 0.4408],
        std=[0.2675, 0.2565, 0.2761],
    ),
}


def init_parser():
    """Build and return the command-line parser for the evaluation script.

    The parser is returned without being run; callers invoke ``parse_args``
    themselves.  Allowed values for ``--model`` come from the module-level
    ``model_names`` list.
    """
    cli = argparse.ArgumentParser(description='CIFAR10 quick evaluation script')
    add = cli.add_argument

    # --- data loading -----------------------------------------------------
    add('--dataset',
        default='cifar10',
        choices=['cifar10', 'cifar100'],
        type=str.lower)
    add('-j', '--workers',
        type=int, default=4, metavar='N',
        help='number of data loading workers (default: 4)')
    add('--print-freq',
        type=int, default=10, metavar='N',
        help='log frequency (by iteration)')
    add('--checkpoint-path',
        default='checkpoint.pth',
        type=str,
        help='path to checkpoint (default: checkpoint.pth)')
    add('-b', '--batch-size',
        dest='batch_size',
        type=int, default=128, metavar='N',
        help='mini-batch size (default: 128)')

    # --- model selection and architecture ---------------------------------
    add('-m', '--model',
        dest='model',
        default='cct_2',
        choices=model_names,
        type=str.lower)
    add('-p', '--positional-embedding',
        dest='positional_embedding',
        default='learnable',
        choices=['learnable', 'sine', 'none'],
        type=str.lower)
    add('--conv-layers',
        type=int, default=2,
        help='number of convolutional layers (cct only)')
    add('--conv-size',
        type=int, default=3,
        help='convolution kernel size (cct only)')
    add('--patch-size',
        type=int, default=4,
        help='image patch size (vit and cvt only)')

    # --- device and dataset download --------------------------------------
    add('--gpu-id', type=int, default=0)
    add('--no-cuda',
        action='store_true',
        help='disable cuda')
    add('--download',
        action='store_true',
        help='download dataset (or verify if already downloaded)')

    return cli


def main():
    """Evaluate every saved model in ``cifar10_models2/`` under salt-and-pepper noise.

    Parses the command line, builds the test split of the selected dataset
    once, then for each regular file in ``cifar10_models2/`` (sorted by name)
    loads it with ``torch.load``, moves it to the selected device, and runs
    ``cls_sp_validate`` on it (which prints the per-noise-level accuracies).

    Only ``--dataset``, ``--workers``, ``--batch-size``, ``--gpu-id``,
    ``--no-cuda`` and ``--download`` are read here; the remaining parsed
    options are not consulted by this function.
    """
    parser = init_parser()
    args = parser.parse_args()

    dataset_cfg = DATASETS[args.dataset]
    img_size = dataset_cfg['img_size']
    img_mean, img_std = dataset_cfg['mean'], dataset_cfg['std']

    # Resolve the target device once instead of re-checking per checkpoint.
    if (not args.no_cuda) and torch.cuda.is_available():
        torch.cuda.set_device(args.gpu_id)
        device = torch.device('cuda', args.gpu_id)
    else:
        device = torch.device('cpu')

    folder = 'cifar10_models2/'
    # Named ``checkpoint_files`` (not ``models``) so the imported ``models``
    # module is not shadowed inside this function.
    checkpoint_files = sorted(
        entry for entry in os.listdir(folder)
        if os.path.isfile(os.path.join(folder, entry)))
    print(checkpoint_files)

    # The dataset and loader do not depend on the checkpoint, so they are
    # built once and reused for every model.
    val_dataset = getattr(datasets, args.dataset.upper())(
        root='./cifar/', train=False, download=args.download,
        transform=transforms.Compose([
            transforms.Resize(img_size),
            transforms.ToTensor(),
            transforms.Normalize(mean=img_mean, std=img_std),
        ]))

    val_loader = torch.utils.data.DataLoader(
        val_dataset,
        batch_size=args.batch_size, shuffle=False,
        num_workers=args.workers)

    for checkpoint_name in checkpoint_files:
        # NOTE(security): torch.load unpickles the file, which can execute
        # arbitrary code -- only point this at checkpoints you trust.
        # map_location lets checkpoints saved on a GPU load on CPU-only runs.
        model = torch.load(os.path.join(folder, checkpoint_name),
                           map_location=device)
        print(checkpoint_name)

        model.to(device)
        model.eval()
        cls_sp_validate(val_loader, model)


def accuracy(output, target):
    """Compute the top-1 accuracy of ``output`` against ``target`` in percent.

    Args:
        output: score tensor; the highest-scoring index along dim 1 is taken
            as the prediction for each row.
        target: tensor of ground-truth class indices, one per row of
            ``output``.

    Returns:
        A one-element list whose entry is a 1-element float tensor holding
        ``100 * (#correct) / target.size(0)``.
    """
    with torch.no_grad():
        n_samples = target.size(0)

        # Index of the best score per row, transposed so predictions run
        # along the last dimension like the broadcast target below.
        top1 = output.topk(1, 1, True, True)[1].t()
        hits = top1.eq(target.view(1, -1).expand_as(top1))

        n_hits = hits[:1].flatten().float().sum(0, keepdim=True)
        return [n_hits.mul_(100.0 / n_samples)]


def cls_validate(val_loader, model, args, time_begin=None):
    """Return the top-1 accuracy (in percent) of ``model`` over ``val_loader``.

    Args:
        val_loader: iterable yielding ``(images, target)`` batches.
        model: the network to evaluate; it is switched to eval mode.
        args: parsed options; ``no_cuda`` and ``gpu_id`` decide whether each
            batch is moved to a GPU before the forward pass.
        time_begin: accepted for backward compatibility; not used, since the
            timing report that consumed it is disabled.

    Returns:
        The sample-weighted mean of the per-batch ``accuracy`` values.

    Raises:
        ZeroDivisionError: if ``val_loader`` yields no samples.

    The previous periodic-logging branch evaluated ``i % args.print_freq``,
    which raised ``ZeroDivisionError`` for ``--print-freq 0`` while only
    computing a value that was never used; it has been removed.
    """
    model.eval()
    use_cuda = (not args.no_cuda) and torch.cuda.is_available()

    weighted_acc = 0.0
    n = 0
    with torch.no_grad():
        for images, target in val_loader:
            if use_cuda:
                images = images.cuda(args.gpu_id, non_blocking=True)
                target = target.cuda(args.gpu_id, non_blocking=True)

            output = model(images)

            # Weight each batch by its size so a short final batch does not
            # skew the overall mean.
            acc1 = accuracy(output, target)
            n += images.size(0)
            weighted_acc += float(acc1[0] * images.size(0))

    return weighted_acc / n


def cls_noisy_validate(val_loader, model, args, time_begin=None):
    """Evaluate ``model`` under additive Gaussian noise of increasing strength.

    For each noise level ``eps`` the whole ``val_loader`` is traversed and
    ``eps * N(0, 1)`` noise is added to every batch before the forward pass.

    Args:
        val_loader: iterable yielding ``(images, target)`` batches.
        model: the network to evaluate; it is switched to eval mode.
        args: parsed options; ``no_cuda`` and ``gpu_id`` decide whether each
            batch is moved to a GPU.
        time_begin: accepted for backward compatibility; not used, since the
            timing report that consumed it is disabled.

    Returns:
        A list with one top-1 accuracy (in percent) per noise level, in the
        order of the levels.  The list is also printed.

    Raises:
        ZeroDivisionError: if ``val_loader`` yields no samples.
    """
    noise_levels = [0, 0.02, 0.04, 0.06, 0.08, 0.1, 0.12, 0.15, 0.18, 0.2,
                    0.25, 0.3, 0.35, 0.4, 0.45, 0.5]
    use_cuda = (not args.no_cuda) and torch.cuda.is_available()

    perturbed_test_accs = []
    model.eval()
    with torch.no_grad():
        for eps in noise_levels:
            weighted_acc = 0.0
            n = 0
            for images, target in val_loader:
                if use_cuda:
                    images = images.cuda(args.gpu_id, non_blocking=True)
                    target = target.cuda(args.gpu_id, non_blocking=True)

                # Sample on the CPU as before, then move the noise to
                # wherever the batch lives (previously hard-coded to 'cuda',
                # which failed on CPU-only runs).  The addition is
                # out-of-place so the loader's own tensor is never modified.
                noise = torch.randn(images.shape).float().to(images.device)
                images = images + eps * noise

                output = model(images)

                acc1 = accuracy(output, target)
                n += images.size(0)
                weighted_acc += float(acc1[0] * images.size(0))

            perturbed_test_accs.append(weighted_acc / n)

    print(perturbed_test_accs)
    return perturbed_test_accs



def sp(image, amount):
    """Return a copy of ``image`` with a fraction of its pixels forced to extremes.

    Half of the affected pixels are set to the image's minimum value and half
    to its maximum value.  Positions are drawn without replacement from
    NumPy's global random state, independently for the two passes, so a pixel
    picked in both passes ends up at the maximum.

    Args:
        image: NumPy array of any shape (previously only 32x32 worked).
        amount: total fraction of pixels to corrupt; each pass touches
            ``ceil(amount * image.size * 0.5)`` pixels.

    Returns:
        A new array with the same shape and dtype as ``image``; the input is
        left untouched.

    Raises:
        ValueError: propagated from ``np.random.choice`` when a pass would
            need more pixels than the image has, or a negative count.
    """
    s_vs_p = 0.5
    n_pixels = image.size
    if n_pixels == 0:
        return np.copy(image)

    # Take both extremes from the clean image so the first pass cannot
    # change the value written by the second one.
    low, high = np.min(image), np.max(image)

    # ``flatten`` always returns a fresh copy, so the caller's array is safe.
    flat = np.asarray(image).flatten()

    # First pass: force pixels to the minimum value.
    num_low = int(np.ceil(amount * n_pixels * s_vs_p))
    idx = np.random.choice(n_pixels, num_low, replace=False)
    flat[idx] = low

    # Second pass: force pixels to the maximum value.
    num_high = int(np.ceil(amount * n_pixels * (1. - s_vs_p)))
    idx = np.random.choice(n_pixels, num_high, replace=False)
    flat[idx] = high

    return flat.reshape(image.shape)
  
def sp_wrapper(data, amount):
    """Apply ``sp`` noise in place to channel 0 of every sample in ``data``.

    NumPy's global random state is re-seeded with 12345 on every call, so
    each call draws the same sequence of noise positions.

    Args:
        data: tensor indexed as ``data[i, 0, :, :]``; it is modified in place.
        amount: noise fraction forwarded to ``sp``.

    Returns:
        The same ``data`` tensor, after modification.

    NOTE(review): only channel index 0 is perturbed; the other channels are
    left untouched -- confirm this is intended for multi-channel images.
    """
    np.random.seed(12345)
    for i in range(data.shape[0]):
        clean = data[i, 0, :, :].detach().cpu().numpy()
        noisy = sp(clean, amount)
        # Write back on the tensor's own device and dtype; this used to be
        # hard-coded to 'cuda', which failed on CPU-only runs.
        data[i, 0, :, :] = torch.tensor(noisy).to(device=data.device,
                                                  dtype=data.dtype)

    return data

def cls_sp_validate(val_loader, model, time_begin=None):
    """Evaluate ``model`` under salt-and-pepper noise of increasing strength.

    For each noise level the whole ``val_loader`` is traversed, every batch
    is passed through ``sp_wrapper`` and then through the model.

    Args:
        val_loader: iterable yielding ``(images, target)`` batches.
        model: the network to evaluate; it is switched to eval mode.  If it
            returns a tuple, the first element is used as the logits.
        time_begin: accepted for backward compatibility; not used.

    Returns:
        A list with one top-1 accuracy per noise level, as a fraction in
        ``[0, 1]`` (correct / total, not a percentage).  The list is also
        printed.

    Raises:
        ZeroDivisionError: if ``val_loader`` yields no samples.
    """
    noise_levels = [0, 0.02, 0.04, 0.06, 0.08, 0.1, 0.12, 0.15, 0.18, 0.2,
                    0.25, 0.3, 0.35]

    # Run on whatever device the model already lives on instead of calling
    # ``.cuda()`` unconditionally, which crashed on CPU-only runs.
    try:
        device = next(model.parameters()).device
    except StopIteration:
        # Parameter-free model: fall back to the old preference for CUDA.
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    perturbed_test_accs = []
    model.eval()
    with torch.no_grad():
        for eps in noise_levels:
            n_correct = 0
            n = 0
            for batch, target in val_loader:
                images = batch.to(device, non_blocking=True)
                target = target.to(device, non_blocking=True)
                if images is batch:
                    # ``.to`` was a no-op; copy so the in-place noise below
                    # cannot corrupt a tensor owned by the loader.
                    images = images.clone()

                images = sp_wrapper(images, eps)
                output = model(images)

                model_logits = output[0] if isinstance(output, tuple) else output
                # Index of the largest logit per sample.
                pred = model_logits.max(1, keepdim=True)[1]
                n_correct += pred.eq(target.view_as(pred)).sum().item()
                n += len(images)

            perturbed_test_accs.append(n_correct / n)

    print(perturbed_test_accs)
    return perturbed_test_accs


# Run the evaluation only when this file is executed as a script, not when
# it is imported as a module.
if __name__ == '__main__':
    main()
